import importlib
from os import path as osp


from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import timm


import torch
from timm.models.vision_transformer import VisionTransformer
from einops.layers.torch import Rearrange

from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
import math

from functools import partial
from Temporal_Transformer import Block




class ConvLayer(nn.Module):
    """2-D convolution followed by optional normalisation and optional ReLU.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding applied by the convolution.
        activation: ``None`` disables the non-linearity. Any other value
            (not only ``'relu'``) selects ``nn.ReLU``.
        norm: ``'BN'`` for batch norm, ``'IN'`` for instance norm with
            running statistics; any other value means no normalisation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
        super().__init__()

        # The convolution bias is dropped only when batch norm follows.
        self.conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=(norm != 'BN'),
        )

        self.activation = None if activation is None else nn.ReLU()

        self.norm = norm
        if norm == 'BN':
            self.norm_layer = nn.BatchNorm2d(out_channels)
        elif norm == 'IN':
            self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)

    def forward(self, x):
        """Apply convolution, then normalisation (if any), then ReLU (if any)."""
        y = self.conv2d(x)

        has_norm = self.norm in ('BN', 'IN')
        if has_norm:
            y = self.norm_layer(y)

        if self.activation is None:
            return y
        return self.activation(y)
    

    
    
class STP(nn.Module):
    """Convolutional / temporal-attention stem placed in front of a frozen ViT.

    The stem replaces the backbone's own patch embedding: three
    ``Conv -> strided Conv+BN`` stages reduce the spatial resolution, and the
    first two stages are each followed by a temporal ``Block`` that attends
    over the time axis independently at every spatial location. The resulting
    tokens are passed through the backbone's transformer blocks and head.

    Args:
        model: Backbone exposing ``cls_token``, ``pos_embed``, ``pos_drop``,
            ``blocks``, ``norm`` and ``head``. All of its parameters are
            frozen (``requires_grad = False``).
        config: Object providing ``config`` (per-stage input channels),
            ``emb`` (per-stage embedding widths), ``kernel`` and ``padding``
            (per-stage patch-conv settings), ``T`` (sizes the temporal
            encoding) and ``T1`` / ``T2`` (forwarded to the temporal blocks).
    """

    def __init__(self, model, config):
        super().__init__()

        # NOTE(review): the per-stage channel list is read from
        # `config.config`; confirm this attribute name against the caller.
        in_chan = config.config
        emb = config.emb
        k = config.kernel  # was bound to `K` while the code below reads `k`
        p = config.padding

        self.config = config

        self.model = model

        # Learned additive encoding along the time axis, T + 1 positions.
        self.temporal_encod = nn.Parameter(torch.zeros(1, config.T + 1, emb[0]))

        self.temporal1 = Block(dim=emb[0], num_heads=4, mlp_ratio=4, qkv_bias=True, qk_scale=None,
                drop=0.0, attn_drop=0.0, drop_path=0.03, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                sr_ratio=1, linear=False)

        self.temporal2 = Block(dim=emb[1], num_heads=2, mlp_ratio=4, qkv_bias=True, qk_scale=None,
                drop=0.0, attn_drop=0.0, drop_path=0.06, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                sr_ratio=1, linear=False)

        # Stage 1: spatial stride 4.
        self.Conv1 = ConvLayer(in_chan[0], in_chan[0], kernel_size=5, stride=1, padding=2)
        self.OverlapPatch1 = ConvLayer(in_chan[0], emb[0], kernel_size=k[0], stride=4, padding=p[0], norm='BN')

        # Stage 2: spatial stride 2.
        self.Conv2 = ConvLayer(in_chan[1], in_chan[1], kernel_size=3, stride=1, padding=1)
        self.OverlapPatch2 = ConvLayer(in_chan[1], emb[1], kernel_size=k[1], stride=2, padding=p[1], norm='BN')

        # Stage 3: spatial stride 2.
        self.Conv3 = ConvLayer(in_chan[2], in_chan[2], kernel_size=3, stride=1, padding=1)
        self.OverlapPatch3 = ConvLayer(in_chan[2], emb[2], kernel_size=k[2], stride=2, padding=p[2], norm='BN')

        nn.init.trunc_normal_(self.temporal_encod, std=0.02)

        # Freeze the backbone; only the stem defined above stays trainable.
        for param in self.model.parameters():
            param.requires_grad = False

    def forward_fea(self, x):
        """Turn a clip into a token sequence for the backbone.

        Args:
            x: 5-D tensor laid out as ``(b, t, n, h, w)``.

        Returns:
            Tensor of shape ``(b, h' * w', emb[2])`` where ``h'``/``w'`` are
            the spatial sizes after the third stage.
        """
        # Only the batch size is needed to undo the (b t) / (b h w) folds.
        b = x.shape[0]

        x = rearrange(x, 'b t n h w -> (b t) n h w')
        x = self.Conv1(x)
        x = self.OverlapPatch1(x)
        # Read the real feature-map height instead of assuming h // 4, so the
        # un-fold below is correct for any kernel/padding/input size.
        h1 = x.shape[-2]
        x = rearrange(x, '(b t) c h w -> (b h w) t c', b=b)
        # NOTE(review): broadcasting requires t == config.T + 1 — confirm.
        x = x + self.temporal_encod
        x = self.temporal1(x, self.config.T1, self.config.T2)

        x = rearrange(x, '(b h w) t c -> (b t) c h w', b=b, h=h1)
        x = self.Conv2(x)
        x = self.OverlapPatch2(x)
        h2 = x.shape[-2]
        x = rearrange(x, '(b t) c h w -> (b h w) t c', b=b)
        x = self.temporal2(x, self.config.T1, self.config.T2)

        # Keep only the first temporal position of every spatial location.
        x = x[:, 0]
        x = rearrange(x, '(b h w) c -> b c h w', b=b, h=h2)
        x = self.Conv3(x)
        x = self.OverlapPatch3(x)

        x = rearrange(x, 'b c h w -> b (h w) c', b=b)

        return x

    def forward(self, x):
        """Run the stem, then the frozen backbone, and return the head output."""
        x = self.forward_fea(x)

        # Prepend the backbone's class token and add its positional embedding.
        cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.model.pos_drop(x + self.model.pos_embed)

        x = self.model.blocks(x)
        x = self.model.norm(x)

        # The head is applied to the class-token position only.
        x = self.model.head(x[:, 0])

        return x



def create_model(opt, logger_name):
    """Build an ``STP`` model on top of a pretrained ViT-Small backbone.

    Args:
        opt: Configuration object handed to ``STP`` (it must expose the
            attributes ``STP`` reads from its ``config`` argument).
        logger_name: Unused; kept so existing callers keep working.

    Returns:
        The constructed ``STP`` module.
    """
    vit = timm.create_model(model_name='vit_small_patch16_224', pretrained=True)
    # STP never calls the backbone's patch embedding, so it is dropped.
    del vit.patch_embed

    # The original passed an undefined name `config` here (NameError).
    # NOTE(review): `opt` is assumed to be the intended configuration — confirm.
    model = STP(vit, opt)

    return model


